﻿// Copyright (c) Files Community
// Licensed under the MIT License.

namespace Files.Core.SourceGenerator.Generators
{
	/// <summary>
	/// Incremental generator that implements partial methods annotated with
	/// <c>Files.Shared.Attributes.GeneratedVTableFunctionAttribute</c> by invoking the native function stored at
	/// the attribute's <c>Index</c> slot of the containing struct's <c>lpVtbl</c> table.
	/// </summary>
	/// <remarks>
	/// Only non-static <c>public</c>/<c>internal</c> partial methods declared directly in a partial struct are processed,
	/// and one <c>*_VTableFunctions.g.cs</c> file is emitted per containing struct. The generated part declares the
	/// <c>lpVtbl</c> field itself.
	/// </remarks>
	[Generator(LanguageNames.CSharp)]
	internal class VTableFunctionGenerator : IIncrementalGenerator
	{
		public void Initialize(IncrementalGeneratorInitializationContext context)
		{
			var sources = context.SyntaxProvider.ForAttributeWithMetadataName(
				"Files.Shared.Attributes.GeneratedVTableFunctionAttribute",
				static (node, token) =>
				{
					token.ThrowIfCancellationRequested();

					// Check if the method has partial modifier and is public or internal (and not static)
					if (node is not MethodDeclarationSyntax { AttributeLists.Count: > 0 } method ||
						!method.Modifiers.Any(SyntaxKind.PartialKeyword) ||
						!(method.Modifiers.Any(SyntaxKind.PublicKeyword) || method.Modifiers.Any(SyntaxKind.InternalKeyword)) ||
						method.Modifiers.Any(SyntaxKind.StaticKeyword))
						return false;

					// Check if the type containing the method has partial modifier and is a struct
					if (node.Parent is not TypeDeclarationSyntax { Keyword.RawKind: (int)SyntaxKind.StructKeyword, Modifiers: { } modifiers } ||
						!modifiers.Any(SyntaxKind.PartialKeyword))
						return false;

					return true;
				},
				static (context, token) => CreateTarget(context, token))
				// Methods whose attribute has no usable Index are dropped here instead of aborting the whole generator.
				.Where(static item => item is not null)
				.Select(static (item, _) => item.GetValueOrDefault())
				.Collect()
				.Select(static (items, token) =>
				{
					token.ThrowIfCancellationRequested();

					// C# type names are case-sensitive; an ignore-case comparer would merge e.g. "Foo" and "foo"
					// into a single generated struct.
					return items.GroupBy(static item => item.Function.FullyQualifiedParentTypeName, StringComparer.Ordinal).ToImmutableArray();
				});

			context.RegisterSourceOutput(sources, static (productionContext, groups) =>
			{
				foreach (var group in groups)
				{
					// Materialize each group once instead of re-enumerating it for every use.
					var targets = group.ToImmutableArray();
					var first = targets[0].Function;
					var namespacePrefix = string.IsNullOrEmpty(first.ParentTypeNamespace) ? string.Empty : $"{first.ParentTypeNamespace}.";

					productionContext.AddSource($"{namespacePrefix}{first.ParentTypeName}_VTableFunctions.g.cs", GenerateVtableFunctionsForStruct(targets));
				}
			});
		}

		/// <summary>
		/// Builds the generation model for a single attributed method.
		/// </summary>
		/// <param name="context">Attribute context from <c>ForAttributeWithMetadataName</c>; its target passed the syntax predicate.</param>
		/// <param name="token">Cancellation token for the generator run.</param>
		/// <returns>
		/// The function description, the method's accessibility keyword, and whether any declaration of the containing
		/// struct spells out an accessibility modifier; or <see langword="null"/> when the attribute has no <c>int</c>
		/// <c>Index</c> named argument.
		/// </returns>
		private static (VTableFunctionInfo Function, string MethodAccessibility, bool TypeHasExplicitAccessibility)? CreateTarget(GeneratorAttributeSyntaxContext context, global::System.Threading.CancellationToken token)
		{
			token.ThrowIfCancellationRequested();

			// Without an int Index there is no vtable slot to call. Skipping the method (the compiler then reports the
			// missing partial implementation on it) beats unboxing null, which would throw and abort generation entirely.
			int? vtableIndex = null;
			foreach (var argument in context.Attributes[0].NamedArguments)
			{
				if (argument.Key == "Index" && argument.Value.Value is int value)
				{
					vtableIndex = value;
					break;
				}
			}

			if (vtableIndex is not int index)
				return null;

			var methodSymbol = (IMethodSymbol)context.TargetSymbol;
			var containingType = methodSymbol.ContainingType;
			var containingNamespace = containingType.ContainingNamespace;

			var fullyQualifiedParentTypeName = containingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
			// The global namespace would otherwise render as "<global namespace>", which is not a valid namespace name.
			var structNamespace = containingNamespace.IsGlobalNamespace ? string.Empty : containingNamespace.ToDisplayString();
			var returnTypeName = methodSymbol.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
			var parameters = methodSymbol.Parameters
				.Select(static x => new ParameterTypeNamePair(x.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), x.Name))
				.ToImmutableArray();

			var function = new VTableFunctionInfo(fullyQualifiedParentTypeName, structNamespace, containingType.Name, methodSymbol.ReturnsVoid, methodSymbol.Name, returnTypeName, index, new(parameters));

			// The implementing declaration must repeat the defining declaration's accessibility exactly;
			// the syntax predicate only admits public or internal methods.
			var methodAccessibility = context.TargetNode is MethodDeclarationSyntax declaration && declaration.Modifiers.Any(SyntaxKind.InternalKeyword)
				? "internal"
				: "public";

			// If any part of the struct already states its accessibility, the generated part must not state a conflicting one.
			var typeHasExplicitAccessibility = containingType.DeclaringSyntaxReferences
				.Select(reference => reference.GetSyntax(token))
				.OfType<TypeDeclarationSyntax>()
				.Any(static typeDeclaration =>
					typeDeclaration.Modifiers.Any(SyntaxKind.PublicKeyword) ||
					typeDeclaration.Modifiers.Any(SyntaxKind.InternalKeyword) ||
					typeDeclaration.Modifiers.Any(SyntaxKind.PrivateKeyword) ||
					typeDeclaration.Modifiers.Any(SyntaxKind.ProtectedKeyword));

			return (function, methodAccessibility, typeHasExplicitAccessibility);
		}

		/// <summary>
		/// Renders the generated partial part of one struct: the <c>lpVtbl</c> field plus an implementation for each attributed method.
		/// </summary>
		/// <param name="targets">Non-empty set of methods that all belong to the same struct.</param>
		/// <returns>The C# source text of the generated file.</returns>
		private static string GenerateVtableFunctionsForStruct(ImmutableArray<(VTableFunctionInfo Function, string MethodAccessibility, bool TypeHasExplicitAccessibility)> targets)
		{
			var first = targets[0];
			var thisPointerType = $"{first.Function.FullyQualifiedParentTypeName}*";

			StringBuilder builder = new();

			builder.AppendLine("// <auto-generated/>");
			builder.AppendLine();
			builder.AppendLine("using global::System.Runtime.CompilerServices;");
			builder.AppendLine();
			builder.AppendLine("#pragma warning disable");
			builder.AppendLine();

			// Structs in the global namespace get no namespace declaration.
			if (!string.IsNullOrEmpty(first.Function.ParentTypeNamespace))
			{
				builder.AppendLine($"namespace {first.Function.ParentTypeNamespace};");
				builder.AppendLine();
			}

			// Fall back to "public" only when no declaring part specifies an accessibility; otherwise inherit it.
			var typeModifiers = first.TypeHasExplicitAccessibility ? "unsafe" : "public unsafe";
			builder.AppendLine($"{typeModifiers} partial struct {first.Function.ParentTypeName}");
			builder.AppendLine("{");
			builder.AppendLine("\tprivate void** lpVtbl;");
			builder.AppendLine();

			for (var i = 0; i < targets.Length; i++)
			{
				// Blank line between consecutive methods, but not after the last one.
				if (i > 0)
					builder.AppendLine();

				AppendFunction(builder, targets[i].Function, targets[i].MethodAccessibility, thisPointerType);
			}

			builder.AppendLine("}");

			return builder.ToString();
		}

		/// <summary>
		/// Appends the implementation of one partial method that forwards to <c>lpVtbl[Index]</c> through an
		/// <c>unmanaged[MemberFunction]</c> function pointer, passing a pointer to this struct as the first argument.
		/// </summary>
		/// <param name="builder">Builder receiving the method text.</param>
		/// <param name="function">The method to implement.</param>
		/// <param name="accessibility">Accessibility keyword copied from the defining declaration.</param>
		/// <param name="thisPointerType">Fully qualified pointer type of the containing struct, e.g. <c>global::N.S*</c>.</param>
		private static void AppendFunction(StringBuilder builder, VTableFunctionInfo function, string accessibility, string thisPointerType)
		{
			// NOTE(review): non-void slots are deliberately declared as returning a plain int, with the result cast to the
			// declared return type — presumably so HRESULT-like wrapper structs come back the way native code returns them. Confirm before changing.
			var nativeReturnType = function.IsReturnTypeVoid ? "void" : "int";

			var declaredParameters = string.Join(", ", function.Parameters.Select(static parameter => $"{parameter.FullyQualifiedTypeName} {parameter.ValueName}"));

			// Join the this-pointer, parameters and return type as one list so parameterless methods get no dangling comma.
			var pointerTypeArguments = string.Join(", ", new[] { thisPointerType }
				.Concat(function.Parameters.Select(static parameter => parameter.FullyQualifiedTypeName))
				.Concat(new[] { nativeReturnType }));
			var callArguments = string.Join(", ", new[] { $"({thisPointerType})global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this)" }
				.Concat(function.Parameters.Select(static parameter => parameter.ValueName)));

			// "return (void)..." is not valid C#, so void methods invoke the pointer as a plain statement.
			var invocationPrefix = function.IsReturnTypeVoid ? string.Empty : $"return ({function.ReturnTypeName})";

			builder.AppendLine("\t[global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]");
			builder.AppendLine($"\t{accessibility} partial {function.ReturnTypeName} {function.Name}({declaredParameters})");
			builder.AppendLine("\t{");
			builder.AppendLine($"\t\t{invocationPrefix}((delegate* unmanaged[MemberFunction]<{pointerTypeArguments}>)(lpVtbl[{function.Index}]))");
			builder.AppendLine($"\t\t\t({callArguments});");
			builder.AppendLine("\t}");
		}
	}
}
